{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Estimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>     <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/estimator\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で表示</a>   </td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab で実行</a>   </td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示</a>   </td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/estimator.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a>   </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oEinLJt2Uowq"
      },
      "source": [
        "このドキュメントでは、`tf.estimator` という高位 TensorFlow API を紹介します。Estimator は以下のアクションをカプセル化します。\n",
        "\n",
        "- トレーニング\n",
        "- 評価\n",
        "- 予測\n",
        "- 配信向けエクスポート\n",
        "\n",
        "TensorFlow は、事前に作成された複数の Estimator を実装します。カスタムの Estimator は依然としてサポートされていますが、主に下位互換性の対策としてサポートされているため、**新しいコードでは、カスタム Estimator を使用してはいけません**。事前に作成された Estimator とカスタム Estimator はすべて、`tf.estimator.Estimator` クラスに基づくクラスです。\n",
        "\n",
        "簡単な例については、[Estimator チュートリアル](../tutorials/estimator/linear.ipynb)を試してください。API デザインの概要については、[ホワイトペーパー](https://arxiv.org/abs/1708.02637)をご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KLdnqg4G2bmz"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cXRQ6mRM5gk0"
      },
      "outputs": [],
      "source": [
        "! pip install -U tensorflow_datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J_-C9ty22dkD"
      },
      "outputs": [],
      "source": [
        "import tempfile\n",
        "import os\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wg5zbBliQvNL"
      },
      "source": [
        "## メリット\n",
        "\n",
        "`tf.keras.Model` と同様に、`estimator` はモデルレベルの抽象です。`tf.estimator` は、`tf.keras` 向けに現在開発段階にある以下の機能を提供しています。\n",
        "\n",
        "- パラメーターサーバーベースのトレーニング\n",
        "- [TFX](http://tensorflow.org/tfx) の完全統合"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQ8fQYt_VD5E"
      },
      "source": [
        "## Estimator の機能\n",
        "\n",
        "Estimator には以下のメリットがあります。\n",
        "\n",
        "- Estimator ベースのモデルは、モデルを変更することなくローカルホストまたは分散マルチサーバー環境で実行できます。さらに、モデルをコーディングし直すことなく、CPU、GPU、または TPU で実行できます。\n",
        "- Estimator では、次を実行する方法とタイミングを制御する安全な分散型トレーニングループを使用できます。\n",
        "    - データの読み込み\n",
        "    - 例外の処理\n",
        "    - チェックポイントファイルの作成と障害からの復旧\n",
        "    - TensorBoard 用のサマリーの保存\n",
        "\n",
        "Estimator を使ってアプリケーションを記述する場合、データ入力パイプラインとモデルを分離する必要があります。分離することで、異なるデータセットを伴う実験を単純化することができます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQ2PsufpgIpM"
      },
      "source": [
        "## 事前作成済み Estimator を使用する\n",
        "\n",
        "既成の Estimator を使うと、基本の TensorFlow API より非常に高い概念レベルで作業することができます。Estimator がすべての「配管作業」を処理してくれるため、計算グラフやセッションの作成などに気を回す必要がありません。さらに、事前作成済みの Estimator では、コード変更を最小限に抑えて多様なモデルアーキテクチャを使った実験を行えます。たとえば `tf.estimator.DNNClassifier` は、密度の高いフィードフォワードのニューラルネットワークに基づく分類モデルをトレーニングする事前作成済みの Estimator クラスです。\n",
        "\n",
        "事前作成済み Estimator に依存する TensorFlow プログラムは、通常、次の 4 つのステップで構成されています。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mIJPPe26gQpF"
      },
      "source": [
        "### 1. 入力関数を作成する\n",
        "\n",
        "たとえば、トレーニングセットをインポートする関数とテストセットをインポートする関数を作成する場合、Estimator は入力が次の 2 つのオブジェクトのペアとしてフォーマットされていることを期待します。\n",
        "\n",
        "- 特徴名のキーと対応する特徴データを含むテンソル（または SparseTensors）の値で構成されるディクショナリ\n",
        "- 1 つ以上のラベルを含むテンソル\n",
        "\n",
        "`input_fn` は上記のフォーマットのペアを生成する `tf.data.Dataset` を返します。\n",
        "\n",
        "たとえば、次のコードは Titanic データセットの `train.csv` ファイルから `tf.data.Dataset` を構築します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7fl_C5d6hEl3"
      },
      "outputs": [],
      "source": [
        "def train_input_fn():\n",
        "  titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n",
        "  titanic = tf.data.experimental.make_csv_dataset(\n",
        "      titanic_file, batch_size=32,\n",
        "      label_name=\"survived\")\n",
        "  titanic_batches = (\n",
        "      titanic.cache().repeat().shuffle(500)\n",
        "      .prefetch(tf.data.experimental.AUTOTUNE))\n",
        "  return titanic_batches"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CjyrQGb3mCcp"
      },
      "source": [
        "`input_fn` は、`tf.Graph` で実行し、グラフテンソルを含む `(features_dics, labels)` ペアを直接返すこともできますが、定数を返すといった単純なケースではない場合に、エラーが発生しやすくなります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yJYjWUMxgTnq"
      },
      "source": [
        "### 2. 特徴量カラムを定義する\n",
        "\n",
        "`tf.feature_column` は、特徴量名、その型、およびすべての入力前処理を特定します。\n",
        "\n",
        "たとえば、次のスニペットは 3 つの特徴量カラムを作成します。\n",
        "\n",
        "- 最初の特徴量カラムは、浮動小数点数の入力として直接 `age` 特徴量を使用します。\n",
        "- 2 つ目の特徴量カラムは、カテゴリカル入力として `class` 特徴量を使用します。\n",
        "- 3 つ目の特徴量カラムは、カテゴリカル入力として `embark_town` を使用しますが、オプションを列挙する必要がないように、またオプション数を設定するために、`hashing trick` を使用します。\n",
        "\n",
        "詳細については、[特徴量カラムのチュートリアル](https://www.tensorflow.org/tutorials/keras/feature_columns)をご覧ください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFd8Dnrmhjhr"
      },
      "outputs": [],
      "source": [
        "age = tf.feature_column.numeric_column('age')\n",
        "cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) \n",
        "embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UIjqAozjgXdr"
      },
      "source": [
        "### 3. 関連する事前作成済み Estimator をインスタンス化する\n",
        "\n",
        "`LinearClassifier` という事前作成済み Estimator のインスタンス化の例を次に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CDOx6lZVoVB8"
      },
      "outputs": [],
      "source": [
        "model_dir = tempfile.mkdtemp()\n",
        "model = tf.estimator.LinearClassifier(\n",
        "    model_dir=model_dir,\n",
        "    feature_columns=[embark, cls, age],\n",
        "    n_classes=2\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QGl9oYuFoYj6"
      },
      "source": [
        "詳細については、[線形分類器のチュートリアル](https://www.tensorflow.org/tutorials/estimator/linear)をご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sXNBeY-oVxGQ"
      },
      "source": [
        "### 4. トレーニング、評価、または推論メソッドを呼び出す\n",
        "\n",
        "すべての Estimator には、`train`、`evaluate`、および `predict` メソッドがあります。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iGaJKkmVBgo2"
      },
      "outputs": [],
      "source": [
        "model = model.train(input_fn=train_input_fn, steps=100)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CXkivCNq0vfH"
      },
      "outputs": [],
      "source": [
        "result = model.evaluate(train_input_fn, steps=10)\n",
        "\n",
        "for key, value in result.items():\n",
        "  print(key, \":\", value)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CPLD8n4CLVi_"
      },
      "outputs": [],
      "source": [
        "for pred in model.predict(train_input_fn):\n",
        "  for key, value in pred.items():\n",
        "    print(key, \":\", value)\n",
        "  break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cbmrm9pFg5vo"
      },
      "source": [
        "### 事前作成済み Estimator のメリット\n",
        "\n",
        "事前作成済み Estimator は、次のようなベストプラクティスをエンコードするため、さまざまなメリットがあります。\n",
        "\n",
        "- さまざまな部分の計算グラフをどこで実行するかを決定し、単一のマシンまたはクラスタに戦略を実装するためのベストプラクティス。\n",
        "- イベント（要約）の書き込みと普遍的に役立つ要約のベストプラクティス。\n",
        "\n",
        "事前作成済み Estimator を使用しない場合は、上記の特徴量を独自に実装する必要があります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oIaPjYgnZdn6"
      },
      "source": [
        "## カスタム Estimator\n",
        "\n",
        "事前作成済みかカスタムかに関係なく、すべての Estimator の中核は、*モデル関数*の `model_fn` にあります。これは、トレーニング、評価、および予測に使用するグラフを構築するメソッドです。事前作成済み Estimator を使用する場合は、モデル関数はすでに実装されていますが、カスタム Estimator を使用する場合は、モデル関数を自分で記述する必要があります。\n",
        "\n",
        "> 注意: カスタム `model_fn` は 1.x スタイルのグラフモードでそのまま実行します。つまり、Eager execution はなく、依存関係の自動制御もないため、`tf.estimator` からカスタム `model_fn` に移行する必要があります。代替の API は `tf.keras` と `tf.distribute` です。トレーニングの一部に `Estimator` を使用する必要がある場合は、`tf.keras.estimator.model_to_estimator` コンバータを使用して `keras.Model` から `Estimator` を作成する必要があります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P7aPNnXUbN4j"
      },
      "source": [
        "## Keras モデルから Estimator を作成する\n",
        "\n",
        "`tf.keras.estimator.model_to_estimator` を使用して、既存の Keras モデルを Estimator に変換できます。モデルコードを最新の状態に変更したくても、トレーニングパイプラインに Estimator が必要な場合に役立ちます。\n",
        "\n",
        "Keras MobileNet V2 モデルをインスタンス化し、トレーニングに使用する optimizer、loss、および metrics とともにモデルをコンパイルします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XE6NMcuGeDOP"
      },
      "outputs": [],
      "source": [
        "keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(\n",
        "    input_shape=(160, 160, 3), include_top=False)\n",
        "keras_mobilenet_v2.trainable = False\n",
        "\n",
        "estimator_model = tf.keras.Sequential([\n",
        "    keras_mobilenet_v2,\n",
        "    tf.keras.layers.GlobalAveragePooling2D(),\n",
        "    tf.keras.layers.Dense(1)\n",
        "])\n",
        "\n",
        "# Compile the model\n",
        "estimator_model.compile(\n",
        "    optimizer='adam',\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A3hcxzcEfYfX"
      },
      "source": [
        "コンパイルされた Keras モデルから `Estimator` を作成します。Keras モデルの初期化状態が、作成した `Estimator` に維持されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UCSSifirfyHk"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8jRNRVb_fzGT"
      },
      "source": [
        "派生した `Estimator` をほかの `Estimator` と同じように扱います。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rv9xJk51e1fB"
      },
      "outputs": [],
      "source": [
        "IMG_SIZE = 160  # All images will be resized to 160x160\n",
        "\n",
        "def preprocess(image, label):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image = (image/127.5) - 1\n",
        "  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fw8OjwujVBkc"
      },
      "outputs": [],
      "source": [
        "def train_input_fn(batch_size):\n",
        "  data = tfds.load('cats_vs_dogs', as_supervised=True)\n",
        "  train_data = data['train']\n",
        "  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)\n",
        "  return train_data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JMb0cuy0gbTi"
      },
      "source": [
        "トレーニングするには、Estimator の train 関数を呼び出します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4JsvMp8Jge80"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvr_rAzngY9v"
      },
      "source": [
        "同様に、評価するには、Estimator の evaluate 関数を呼び出します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kVNPqysQgYR2"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5HeTOvCYbjZb"
      },
      "source": [
        "詳細については、`tf.keras.estimator.model_to_estimator` のドキュメントを参照してください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zGG1tOM0L6iM"
      },
      "source": [
        "## Estimator でオブジェクトベースのチェックポイントを保存する\n",
        "\n",
        "Estimator はデフォルトで、[チェックポイントガイド](checkpoint.ipynb)で説明したオブジェクトグラフではなく、変数名でチェックポイントを保存します。`tf.train.Checkpoint` は名前ベースのチェックポイントを読み取りますが、モデルの一部を Estimator の `model_fn` の外側に移動すると変数名が変わることがあります。上位互換性においては、オブジェクトベースのチェックポイントを保存すると、Estimator の内側でモデルをトレーニングし、外側でそれを使用することが容易になります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-8AMJeueNyoM"
      },
      "outputs": [],
      "source": [
        "import tensorflow.compat.v1 as tf_compat"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W5JbCEUGY-Xo"
      },
      "outputs": [],
      "source": [
        "def toy_dataset():\n",
        "  inputs = tf.range(10.)[:, None]\n",
        "  labels = inputs * 5. + tf.range(5.)[None, :]\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    dict(x=inputs, y=labels)).repeat().batch(2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gTZbsIRCZnCU"
      },
      "outputs": [],
      "source": [
        "class Net(tf.keras.Model):\n",
        "  \"\"\"A simple linear model.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(Net, self).__init__()\n",
        "    self.l1 = tf.keras.layers.Dense(5)\n",
        "\n",
        "  def call(self, x):\n",
        "    return self.l1(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6fQsBzJQN2y"
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode):\n",
        "  net = Net()\n",
        "  opt = tf.keras.optimizers.Adam(0.1)\n",
        "  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),\n",
        "                             optimizer=opt, net=net)\n",
        "  with tf.GradientTape() as tape:\n",
        "    output = net(features['x'])\n",
        "    loss = tf.reduce_mean(tf.abs(output - features['y']))\n",
        "  variables = net.trainable_variables\n",
        "  gradients = tape.gradient(loss, variables)\n",
        "  return tf.estimator.EstimatorSpec(\n",
        "    mode,\n",
        "    loss=loss,\n",
        "    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),\n",
        "                      ckpt.step.assign_add(1)),\n",
        "    # Tell the Estimator to save \"ckpt\" in an object-based format.\n",
        "    scaffold=tf_compat.train.Scaffold(saver=ckpt))\n",
        "\n",
        "tf.keras.backend.clear_session()\n",
        "est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')\n",
        "est.train(toy_dataset, steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tObYHnrrb_mL"
      },
      "source": [
        "その後、`tf.train.Checkpoint` は Estimator のチェックポイントをその `model_dir` から読み込むことができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q6IP3Y_wb-fs"
      },
      "outputs": [],
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "net = Net()\n",
        "ckpt = tf.train.Checkpoint(\n",
        "  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)\n",
        "ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))\n",
        "ckpt.step.numpy()  # From est.train(..., steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk5wWyuMpuHx"
      },
      "source": [
        "## Estimator の SavedModel\n",
        "\n",
        "Estimator は、[`tf.Estimator.export_saved_model`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_saved_model) によって SavedModel をエクスポートします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9KQq5qzpzbK"
      },
      "outputs": [],
      "source": [
        "input_column = tf.feature_column.numeric_column(\"x\")\n",
        "\n",
        "estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])\n",
        "\n",
        "def input_fn():\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    ({\"x\": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)\n",
        "estimator.train(input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y9qCa6J6FVS5"
      },
      "source": [
        "`Estimator` を保存するには、`serving_input_receiver` を作成する必要があります。この関数は、SavedModel が受け取る生データを解析する `tf.Graph` の一部を構築します。\n",
        "\n",
        "`tf.estimator.export` モジュールには、これらの `receivers` を構築するための関数が含まれています。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XJ4PJ-Cl4060"
      },
      "source": [
        "次のコードは、`feature_columns` に基づき、[tf-serving](https://tensorflow.org/serving) と合わせて使用されることの多いシリアル化された `tf.Example` プロトコルバッファを受け入れるレシーバーを構築します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lnmsmGOQFPED"
      },
      "outputs": [],
      "source": [
        "tmpdir = tempfile.mkdtemp()\n",
        "\n",
        "serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "  tf.feature_column.make_parse_example_spec([input_column]))\n",
        "\n",
        "estimator_base_path = os.path.join(tmpdir, 'from_estimator')\n",
        "estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q7XtbLMDaie2"
      },
      "source": [
        "また、Python からモデルを読み込んで実行することも可能です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c_BUBBNB1UH9"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(estimator_path)\n",
        "\n",
        "def predict(x):\n",
        "  example = tf.train.Example()\n",
        "  example.features.feature[\"x\"].float_list.value.extend([x])\n",
        "  return imported.signatures[\"predict\"](\n",
        "    examples=tf.constant([example.SerializeToString()]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C1ylWZCQ1ahG"
      },
      "outputs": [],
      "source": [
        "print(predict(1.5))\n",
        "print(predict(3.5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_IrCCm0-isqA"
      },
      "source": [
        "`tf.estimator.export.build_raw_serving_input_receiver_fn` を使用すると、`tf.train.Example` の代わりに生のテンソルを取る入力関数を作成することができます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nO0hmFCRoIll"
      },
      "source": [
        "## Estimator を使った `tf.distribute.Strategy` の使用（制限サポート）\n",
        "\n",
        "`tf.estimator` は、もともと非同期パラメーターサーバー手法をサポートしていた分散型トレーニング TensorFlow API です。`tf.estimator` は現在では `tf.distribute.Strategy` をサポートするようになっています。`tf.estimator` を使用している場合は、コードを少し変更するだけで、分散型トレーニングに変更することができます。これにより、Estimator ユーザーは複数の GPU と複数のワーカーだけでなく、TPU でも同期分散型トレーニングを実行できるようになりましたが、Estimator でのこのサポートには制限があります。詳細については、以下に示す「[現在、何がサポートされていますか](#estimator_support)」セクションをご覧ください。\n",
        "\n",
        "Estimator での `tf.distribute.Strategy` の使用は、Keras の事例とわずかに異なります。`strategy.scope` を使用する代わりに、ストラテジーオブジェクトを Estimator の `RunConfig` に渡します。\n",
        "\n",
        "詳細については、[分散型トレーニングガイド](distributed_training.ipynb)をご覧ください。\n",
        "\n",
        "次は、事前に作成された Estimator `LinearRegressor` と `MirroredStrategy` を使ってこの動作を示すコードスニペットです。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oGFY5nW_B3YU"
      },
      "outputs": [],
      "source": [
        "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
        "config = tf.estimator.RunConfig(\n",
        "    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)\n",
        "regressor = tf.estimator.LinearRegressor(\n",
        "    feature_columns=[tf.feature_column.numeric_column('feats')],\n",
        "    optimizer='SGD',\n",
        "    config=config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n6eSfLN5RGY8"
      },
      "source": [
        "ここでは、事前に作成された Estimator が使用されていますが、同じコードはカスタム Estimator でも動作します。`train_distribute` はトレーニングの分散方法を判定し、`eval_distribute` は評価の分散方法を判定します。この点も、トレーニングと評価に同じストラテジーを使用する Keras と異なるところです。\n",
        "\n",
        "入力関数を使用して、この Estimator をトレーニングし、評価することができます。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ky2ve2PB3YP"
      },
      "outputs": [],
      "source": [
        "def input_fn():\n",
        "  dataset = tf.data.Dataset.from_tensors(({\"feats\":[1.]}, [1.]))\n",
        "  return dataset.repeat(1000).batch(10)\n",
        "regressor.train(input_fn=input_fn, steps=10)\n",
        "regressor.evaluate(input_fn=input_fn, steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hgaU9xQSSk2x"
      },
      "source": [
        "Estimator と Keras のもう 1 つの違いとして強調すべき点は入力の処理方法です。Keras では、データセットの各バッチは複数のレプリカに自動的に分断されますが、Estimator の場合は、バッチの自動分断やワーカーをまたいで自動的にシャーディングすることもありません。ワーカーやデバイスでのデータの分散方法はユーザーが完全に制御するものであるため、`input_fn` を提供してデータの分散方法を指定する必要があります。\n",
        "\n",
        "`input_fn` はワーカー当たり一度呼び出されるため、ワーカー当たり 1 つのデータセットが与えられます。次に、そのデータセットの 1 つのバッチがそのワーカーの 1 つのレプリカに供給され、したがって、1 つのワーカーの N 個のレプリカに対して N 個のバッチが消費されることになります。言い換えると、`input_fn` が返すデータセットは、サイズ `PER_REPLICA_BATCH_SIZE` のバッチを提供するということです。ステップのグローバルバッチサイズは、`PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync` として取得することができます。\n",
        "\n",
        "マルチワーカートレーニングを行う場合は、データをワーカー間で分割するか、それぞれにランダムシードを使用してシャッフルする必要があります。これを行う方法の例は、「[Estimator を使ったマルチワーカートレーニング](../tutorials/distribute/multi_worker_with_estimator.ipynb)」を参照してください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3ieQKfWZhhL"
      },
      "source": [
        "また、同様に、マルチワーカーとパラメーターサーバーストラテジーを使用することができます。コードは変わりませんが、`tf.estimator.train_and_evaluate` を使用し、クラスタで実行している各バイナリの `TF_CONFIG` 環境変数を設定する必要があります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A_lvUsSLZzVg"
      },
      "source": [
        "<a name=\"estimator_support\"></a>\n",
        "\n",
        "### 現在、何がサポートされていますか？\n",
        "\n",
        "`TPUStrategy` を除くすべてのストラテジーを使った Estimator でのトレーニングのサポートには制限があります。基本的なトレーニングと評価は機能しますが、`v1.train.Scaffold` などの多数の高度な機能はまだ機能しません。また、この統合には多数のバグも存在する可能性があります。現時点では、Keras とカスタムトレーニングループのサポートに注力しているため、このサポートを積極的に改善する予定はありません。可能な限り、それらの API で `tf.distribute` を使用するようにしてください。\n",
        "\n",
        "トレーニング API | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy\n",
        ":-- | :-- | :-- | :-- | :-- | :--\n",
        "Estimator API | 制限サポート | 未サポート | 制限サポート | 制限サポート | 制限サポート\n",
        "\n",
        "### 例とチュートリアル\n",
        "\n",
        "次は、Estimator によるさまざまなストラテジーの使用方法を示す、エンドツーエンドの例です。\n",
        "\n",
        "1. [Estimator を使ったマルチワーカートレーニングのチュートリアル](../tutorials/distribute/multi_worker_with_estimator.ipynb)には、MNIST データセットで `MultiWorkerMirroredStrategy` を使って複数のワーカーをトレーニングする方法が説明されています。\n",
        "2. Kubernetes テンプレートを使った `tensorflow/ecosystem` で[分散ストラテジーによってマルチワーカートレーニングを実行する](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy)エンドツーエンドの例。Keras モデルから始め、`tf.keras.estimator.model_to_estimator` API を使って Estimator に変換します。\n",
        "3. [ResNet50](https://github.com/tensorflow/models/blob/master/official/vision/image_classification/resnet_imagenet_main.py) の公式モデル。`MirroredStrategy` または `MultiWorkerMirroredStrategy` を使ってトレーニングできます。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "estimator.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
